# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# linear head implementation for DUST3R
# --------------------------------------------------------
import math
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F

from dust3r.heads.postprocess import postprocess

class DownSampling(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=kernel_size//2)
        self.norm = nn.BatchNorm2d(out_channels)
        self.A = nn.SELU()

    def forward(self, x):
        return self.A(self.norm(self.conv(x)))
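# Example (illustrative): DownSampling(3, 96) maps [B, 3, 224, 224] -> [B, 96, 112, 112]
# via a stride-2 conv followed by BatchNorm and SELU.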

class MiddleBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, activation=False):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=kernel_size//2)
        self.norm = nn.BatchNorm2d(out_channels)
        self.A = nn.SELU()
        self.activation = activation

    def forward(self, x):
        if self.activation:
            return self.A(self.norm(self.conv(x)))
        else:
            return self.norm(self.conv(x))
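# Example (illustrative): MiddleBlock(6, 256, activation=True) maps [B, 6, H, W] -> [B, 256, H, W];
# stride=1 with 'same'-style padding keeps the spatial resolution unchanged.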

class LinearPts3d(nn.Module):
    """ 
    Linear head for dust3r
    Each token outputs: - 16x16 3D points (+ confidence)
    """

    def __init__(self, net, has_conf=False, skip=False):
        super().__init__()
        self.patch_size = net.patch_embed.patch_size[0]
        self.depth_mode = net.depth_mode
        self.conf_mode = net.conf_mode
        self.has_conf = has_conf
        self.skip = skip

        self.proj = nn.Linear(net.dec_embed_dim, (3 + has_conf)*self.patch_size**2)
        self.skip_proj = None
        if self.skip:
            skip_dim = 256
            self.skip_proj_scale = 1e-4
            self.skip_proj = nn.Sequential(
                MiddleBlock(6, skip_dim, activation=True),
                MiddleBlock(skip_dim, skip_dim, kernel_size=5, activation=True),
                MiddleBlock(skip_dim, skip_dim, kernel_size=5, activation=True),
                MiddleBlock(skip_dim, 3, activation=False),
            )
            # zero-init the last conv of the skip branch so it starts as a no-op
            for m in self.skip_proj[-1].modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.zeros_(m.weight)
                    nn.init.zeros_(m.bias)

    def setup(self, croconet):
        pass

    def forward(self, decout, img_shape):
        H, W = img_shape
        img = None
        if isinstance(decout, tuple):
            decout, img = decout
        tokens = decout[-1] # choose the output of the last layer
        B, S, D = tokens.shape

        # extract 3D points
        feat = self.proj(tokens)  # B, S, (3 + has_conf) * patch_size**2
        feat = feat.transpose(-1, -2).view(B, -1, H // self.patch_size, W // self.patch_size)  # B,S,D -> B,D,S -> B,D,H/p,W/p
        feat = F.pixel_shuffle(feat, self.patch_size)  # B, 3 + has_conf, H, W

        if self.skip:
            assert img is not None, "skip=True requires the input image to be passed alongside decout"
            skip_input = torch.cat([img, feat[:, :3]], dim=1)  # concat RGB image with the raw 3D points
            img_skip = self.skip_proj(skip_input) * self.skip_proj_scale
            feat[:, :3] = feat[:, :3] + img_skip

        # postprocess permutes to B,H,W,C and applies the depth/confidence activation modes
        return postprocess(feat, self.depth_mode, self.conf_mode)
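# Hedged usage sketch (illustrative only): `net` stands for a DUSt3R decoder exposing
# patch_embed.patch_size, dec_embed_dim, depth_mode and conf_mode, and postprocess is
# assumed to return a dict with 'pts3d' (and 'conf' when has_conf=True):
#
#   head = LinearPts3d(net, has_conf=True)
#   out = head(decout, img_shape=(224, 224))   # decout: list of [B, S, D] token maps
#   pts3d = out['pts3d']                       # [B, 224, 224, 3]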



def zero_module(module):
    for p in module.parameters():
        nn.init.zeros_(p)
    return module
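# Example (illustrative): make a freshly added residual branch start as a no-op,
# mirroring the zero-init loop in LinearPts3d.__init__ above.
#   skip_branch = zero_module(nn.Conv2d(6, 3, 3, padding=1))  # outputs zeros until trained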


def build_pytorch_mlp(input_dim, hidden_dim, output_dim, depth=10, bias=False):
    mlp = []
    mlp.append(nn.Linear(input_dim, hidden_dim, bias=bias))
    mlp.append(nn.ReLU())
    for _ in range(depth - 1):
        mlp.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
        mlp.append(nn.ReLU())
    mlp.append(nn.Linear(hidden_dim, output_dim, bias=bias))
    mlp = nn.Sequential(*mlp)
    return mlp
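# Example layout (illustrative): build_pytorch_mlp(768, 256, 3, depth=1) yields
#   Linear(768 -> 256), ReLU, Linear(256 -> 3)        (all without bias)
# i.e. `depth` counts hidden layers: depth=1 is a 2-layer MLP, depth=2 a 3-layer MLP.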


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    # type: (Tensor, float, float, float, float) -> Tensor
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
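# Example (illustrative), as used by GSHead._init_weights below:
#   w = torch.empty(768, 768)
#   trunc_normal_(w, std=0.02)   # samples N(0, 0.02^2) truncated to the absolute range [-2, 2]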

class DownSamplings(nn.Module):
    def __init__(self, in_channels, out_channels, num_layers=2):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            # channels double at every stage and end exactly at out_channels:
            # in_channels -> out/2^(n-1) -> out/2^(n-2) -> ... -> out
            in_channels_i = in_channels if i == 0 else out_channels // (2 ** (num_layers - i))
            out_channels_i = out_channels // (2 ** (num_layers - 1 - i))
            self.layers.append(DownSampling(in_channels_i, out_channels_i))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
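# Example (illustrative): DownSamplings(3, 768, num_layers=4), as used by GSHead below, maps
# [B, 3, 224, 224] through channels 3 -> 96 -> 192 -> 384 -> 768 while each stride-2 stage
# halves the resolution (224 -> 112 -> 56 -> 28 -> 14), giving [B, 768, 14, 14].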


class GSHead(nn.Module):
    def __init__(
        self,
        net,
        mlp_dim=768,  # hidden width of the per-attribute MLPs (256 in the original config)
        mlp_depth=1,  # number of hidden layers; depth=1 gives a 2-layer MLP
        # scale_range=(0.0001, 0.02),
        cp=True,
        radius=0.5,  # same as the original default
        norm_layer=nn.LayerNorm,
        skip=False,
        sh_degree=0,
    ):
        super().__init__()

        self.sh_degree = sh_degree
        self.sh_base = (sh_degree + 1) ** 2  # number of SH coefficients per colour channel

        patch_size = net.patch_embed.patch_size[0] # 16
        self.patch_size = patch_size

        token_dim = net.dec_embed_dim
        self.token_dim = token_dim
        # self.scale_range = scale_range
        self.cp = cp
        self.radius = radius
        self.skip = skip

        self.norm = norm_layer(token_dim)

        patch_size_sqr = patch_size * patch_size

        self.mlp_rgb = build_pytorch_mlp(
            token_dim, mlp_dim, patch_size_sqr * 3 * self.sh_base, depth=mlp_depth, bias=False
        )
        self.mlp_opacity = build_pytorch_mlp(
            token_dim, mlp_dim, patch_size_sqr * 1, depth=mlp_depth, bias=False
        )
        # the `scale2` name keeps checkpoints containing an older `mlp_scale` head
        # from being loaded into this one
        self.mlp_scale2 = build_pytorch_mlp(
            token_dim, mlp_dim, patch_size_sqr * 3, depth=mlp_depth, bias=False
        )
        self.mlp_rotation = build_pytorch_mlp(
            token_dim, mlp_dim, patch_size_sqr * 4, depth=mlp_depth, bias=False
        )
        
        self.cnn_before = None
        self.cnn_after = None

        if self.skip:
            self.cnn_before = DownSamplings(3, mlp_dim, num_layers=4)
            self.cnn_after = nn.Sequential(
                MiddleBlock(3, mlp_dim, activation=True),
                MiddleBlock(mlp_dim, 3 * self.sh_base, activation=False),
            )

        self.apply(self._init_weights)
        
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, decout, img_shape):
        decout, img = decout  # GSHead always receives (decoder outputs, image); img: [B, 3, H, W]
        tokens = decout[-1]
        H, W = img_shape
        B, S, D = tokens.shape  # S = (H // patch_size) * (W // patch_size) tokens, e.g. 14 * 14 for 224x224
        tokens = self.norm(tokens)
        if self.skip:
            img_before_output = self.cnn_before(img)  # [B, 3, H, W] -> [B, mlp_dim, H/16, W/16]
            img_before_output = img_before_output.flatten(-2, -1).permute(0, 2, 1)  # -> [B, S, D], aligned with tokens
            tokens = tokens + img_before_output

        rgb = self.mlp_rgb(tokens)  # [B, S, patch_size**2 * 3 * sh_base]
        opacity = self.mlp_opacity(tokens)
        scale = self.mlp_scale2(tokens)
        rotation = self.mlp_rotation(tokens)

        # if not self.skip:
        #     rgb = torch.sigmoid(rgb) * 2 - 1

        opacity = torch.sigmoid(opacity)

        # scale = self.scale_range[0] + torch.sigmoid(scale) * (
        #     self.scale_range[1] - self.scale_range[0]
        # )
        scale = scale * 0.1  # damp the raw scale predictions (replaces the sigmoid scale_range above)

        output_combine = [rgb, opacity, scale, rotation]
        output_combine_pixel = []
        for feat in output_combine:
            feat = feat.transpose(-1, -2).view(B, -1, H // self.patch_size, W // self.patch_size)  # B, d*p*p, H/p, W/p
            feat = F.pixel_shuffle(feat, self.patch_size)  # B, d, H, W
            feat = feat.permute(0, 2, 3, 1)  # B, H, W, d
            output_combine_pixel.append(feat)
        
        rgb, opacity, scale, rotation = output_combine_pixel
        if self.skip:
            img_after_output = self.cnn_after(img)  # full-resolution skip: [B, 3 * sh_base, H, W]
            img_after_output = img_after_output.permute(0, 2, 3, 1)  # B, H, W, d
            # rgb = torch.sigmoid(rgb + img_after_output) * 2 - 1
            rgb = rgb + img_after_output
        
        rotation = F.normalize(rotation, dim=-1, eps=1e-5)  # normalize quaternions to unit length

        out = {}
        out["rgb"] = rgb
        out["opacity"] = opacity
        out["scale"] = scale
        out["rotation"] = rotation

        # if torch.is_nan(out["rgb"]).any():  

        return out
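

if __name__ == "__main__":
    # Minimal smoke test (illustrative only, not part of the original module): _FakeNet fakes
    # the two attributes GSHead actually reads, patch_embed.patch_size and dec_embed_dim.
    class _FakeNet:
        class patch_embed:
            patch_size = (16, 16)

        dec_embed_dim = 768

    net = _FakeNet()
    B, H, W = 1, 64, 64
    S = (H // 16) * (W // 16)  # 4 x 4 patch grid
    decout = [torch.randn(B, S, net.dec_embed_dim)]
    img = torch.randn(B, 3, H, W)

    head = GSHead(net, skip=True)
    out = head((decout, img), (H, W))
    for k, v in out.items():
        print(k, tuple(v.shape))  # rgb [.., 3*sh_base], opacity [.., 1], scale [.., 3], rotation [.., 4]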
